/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2014, The University of Texas at Austin
   Copyright (C) 2020, Advanced Micro Devices, Inc.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "blis.h"
#include "test_libblis.h"

// Set PRINT to 1 to print the operands, the intermediate matrices used to
// build the reference result, and the vectors compared by the check.
#define PRINT 0

// Static variables.
// op_str  : operation name handed to the test driver.
// o_types : one character per operand (a b c).
// p_types : one character per parameter (uploc transa transb).
// thresh  : residual thresholds {warn, pass} per floating-point type,
//           listed in the order s, c, d, z.
static char *op_str = "gemmt";
static char *o_types = "mmm";								  // a b c
static char *p_types = "uhh";								  // uploc transa transb
static thresh_t thresh[BLIS_NUM_FP_TYPES] = {{1e-04, 1e-05},  // warn, pass for s
											 {1e-04, 1e-05},  // warn, pass for c
											 {1e-13, 1e-14},  // warn, pass for d
											 {1e-13, 1e-14}}; // warn, pass for z

// Local prototypes.

// Runs the tests of the operations that the gemmt test relies on.
void libblis_test_gemmt_deps(
	thread_data_t *tdata,
	test_params_t *params,
	test_op_t *op);

// Runs one gemmt experiment for a datatype (dc_str), parameter (pc_str) and
// storage (sc_str) combination at problem-size index p_cur. The measured
// performance is written to *perf and the residual to *resid.
void libblis_test_gemmt_experiment(
	test_params_t *params,
	test_op_t *op,
	iface_t iface,
	char *dc_str,
	char *pc_str,
	char *sc_str,
	unsigned int p_cur,
	double *perf,
	double *resid);

// Invokes gemmt on the given operands through the interface selected by iface.
void libblis_test_gemmt_impl(
	iface_t iface,
	obj_t *alpha,
	obj_t *a,
	obj_t *b,
	obj_t *beta,
	obj_t *c);

// Stores in *resid a measure of the difference between the computed c and
// the expected result passed as c_orig.
void libblis_test_gemmt_check(
	test_params_t *params,
	obj_t *alpha,
	obj_t *a,
	obj_t *b,
	obj_t *beta,
	obj_t *c,
	obj_t *c_orig,
	double *resid);

void libblis_test_gemmt_deps(
	thread_data_t *tdata,
	test_params_t *params,
	test_op_t *op)
{
	// Test every operation the gemmt experiment and check rely on before
	// gemmt itself is tested:
	//   - randv, randm         : randomizing the operands and the check vector;
	//   - copym, setm, gemm,
	//     addm                 : building the reference result c_ref (a
	//                            broken gemm or setm would silently corrupt
	//                            the reference, so they must be verified too);
	//   - gemv, subv, normfv   : computing the residual in the check.
	libblis_test_randv(tdata, params, &(op->ops->randv));
	libblis_test_randm(tdata, params, &(op->ops->randm));
	libblis_test_normfv(tdata, params, &(op->ops->normfv));
	libblis_test_subv(tdata, params, &(op->ops->subv));
	libblis_test_copym(tdata, params, &(op->ops->copym));
	libblis_test_setm(tdata, params, &(op->ops->setm));
	libblis_test_gemv(tdata, params, &(op->ops->gemv));
	libblis_test_addm(tdata, params, &(op->ops->addm));
	libblis_test_gemm(tdata, params, &(op->ops->gemm));
}

void libblis_test_gemmt(
	thread_data_t *tdata,
	test_params_t *params,
	test_op_t *op)
{
	// Skip the test when it has already run, when the operation is reported
	// as disabled, or when level-3 testing is disabled. The checks are
	// evaluated left to right and stop at the first one that holds.
	if (libblis_test_op_is_done(op) ||
		libblis_test_op_is_disabled(op) ||
		libblis_test_l3_is_disabled(op))
	{
		return;
	}

	// Exercise the operations this test depends on before gemmt itself.
	libblis_test_gemmt_deps(tdata, params, op);

	// Drive the experiment through the sequential front-end interface,
	// the only interface libblis_test_gemmt_impl() handles.
	libblis_test_op_driver(tdata, params, op, BLIS_TEST_SEQ_FRONT_END,
						   op_str, p_types, o_types, thresh,
						   libblis_test_gemmt_experiment);
}

void libblis_test_gemmt_experiment(
	test_params_t *params,
	test_op_t *op,
	iface_t iface,
	char *dc_str,
	char *pc_str,
	char *sc_str,
	unsigned int p_cur,
	double *perf,
	double *resid)
{
	unsigned int n_repeats = params->n_repeats;
	unsigned int i;

	double time_min = DBL_MAX;
	double time;

	num_t datatype;

	dim_t m, k;

	uplo_t uploc;
	trans_t transa, transb;

	obj_t alpha, a, b, beta;
	obj_t c, c_ref, c_org_tri, c_result_tri, c_save;

	// Use the datatype of the first char in the datatype combination string.
	bli_param_map_char_to_blis_dt(dc_str[0], &datatype);

	// Map the dimension specifier to actual dimensions.
	m = libblis_test_get_dim_from_prob_size(op->dim_spec[0], p_cur);
	k = libblis_test_get_dim_from_prob_size(op->dim_spec[1], p_cur);

	// Map parameter characters to BLIS constants.
	bli_param_map_char_to_blis_uplo(pc_str[0], &uploc);
	bli_param_map_char_to_blis_trans(pc_str[1], &transa);
	bli_param_map_char_to_blis_trans(pc_str[2], &transb);

	// Create test scalars.
	bli_obj_scalar_init_detached(datatype, &alpha);
	bli_obj_scalar_init_detached(datatype, &beta);

	// Create test operands (vectors and/or matrices).
	libblis_test_mobj_create(params, datatype, transa,
							 sc_str[1], m, k, &a);
	libblis_test_mobj_create(params, datatype, transb,
							 sc_str[2], k, m, &b);
	libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE,
							 sc_str[0], m, m, &c);
	libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE,
							 sc_str[0], m, m, &c_save);
	libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE,
							 sc_str[0], m, m, &c_ref);
	libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE,
							 sc_str[0], m, m, &c_org_tri);
	libblis_test_mobj_create(params, datatype, BLIS_NO_TRANSPOSE,
							 sc_str[0], m, m, &c_result_tri);

	// Set alpha and beta.
	if (bli_obj_is_real(&c))
	{
		bli_setsc(1.2, 0.0, &alpha);
		bli_setsc(-1.0, 0.0, &beta);
	}
	else
	{
		// For gemmt, both alpha and beta may be complex since, unlike herk,
		// C is symmetric in both the real and complex cases.
		bli_setsc(1.2, 0.5, &alpha);
		bli_setsc(-1.0, 0.5, &beta);
	}

	// Randomize A and B
	libblis_test_mobj_randomize(params, TRUE, &a);
	libblis_test_mobj_randomize(params, TRUE, &b);

	// Apply the remaining parameters.
	// We need to do this before we create the referece matrix
	bli_obj_set_conjtrans(transa, &a);
	bli_obj_set_conjtrans(transb, &b);

	// We want to create two final matrices
	// 1. Input matric c : This will be the random matrix used as input for gemmt
	//    it needs uplo settings for gemmt to decide which half to be updated
	//    for the result.
	// 2. Refernce matrix C_ref: This matrix is expected output from gemmt
	//    This matrix is constructed as explain below.
	//
	//    a. c_org_tri: This matrix contains only the original elements from c
	//       which are not updated by GEMMT operation. All other elements will be set to 0.
	//       This is constructed by performing the GEMM operation using alpha=beta = 0
	//       and setting the uplo to the uplo reqested.
	//
	//    b. c_result_tri: This matrix contains only the elementes that will be updated by gemmt
	//       This matrix is constructed by doing normal GEMM operation and converting the result
	//       to trianguler matrix, this will ensure that all other elements excpet the required
	//       uploc settings are set to 0.
	//
	//    c. Finally c_ref matrix is constucted by adding above to matrices.
	//
	//  3. GEMMT operation will be performed using a, b & c and the results will be compared
	//     with c_ref.

	//
	// Assuming that gemmt is done on lower triangle we can represent
	// this calculation as.
	//
	//                   gemmt(a,b,c) = L(gemm(a,b,c)\U(c)
	//                                = c_results_tri \ C_org_tri
	// (beta * C + alpha * A * B) \ C = ((beta * C + alpha * A * B) \ 0) \ (0\C)
	//
	// C_result_tri = lower trianlge
	// C_org_tri = strictly upper triangle.
	// "\" represents matrix divided into triangles.
	//
	// For upper triangle operations the order of lower and upper matrices in
	// these euqations will be exchanged.

	// Generate random input matrix
	libblis_test_mobj_randomize(params, TRUE, &c);

	// Create the requried copies before setting the uplo attribute
	bli_copym(&c, &c_save);
	bli_copym(&c, &c_org_tri);
	bli_copym(&c, &c_result_tri);
	bli_obj_set_uplo(uploc, &c);
	bli_obj_set_uplo(uploc, &c_save);

	// Create c_org_tri matrix using setm operation, this matrix will
	// have original values from input matrix "c" for all elements outside
	// triangle selected for GEMMT operation.
	bli_obj_set_uplo(uploc, &c_org_tri); // Set to request uplo to set all elemnts in triangle to zero
	bli_setm(&BLIS_ZERO, &c_org_tri);
	bli_obj_toggle_uplo(&c_org_tri); // Toggle uplo now so that untouched triangle is active.

	// GEMMT output is same as GEMM for the triangle selected by uplo
	// So we want to extract this triangle from complete GEMM results
	// We do this by setting the uplo and converting the results
	// to triangluer matrix.
	// Perform gemm operation on original inputs
	bli_gemm(&alpha, &a, &b, &beta, &c_result_tri);
	// Set the values in other triangle to zero by converting it to trianguler matrix
	bli_obj_set_uplo(uploc, &c_result_tri);
	bli_mktrim(&c_result_tri);

	// Now we have two matrices with opposite triangles set to zero
	// c_result_tri: It has output of GEMM in selected triangle (including diagonal)
	//               Rest of its elements are set to zero.
	// c_org_tri: It has values from orignal C matrix in the non-selected triangle
	//            Rest of the elements including diagonal are set to zero
	// The result of the GEMMT operation will be combined matrix of thse two matrics
	// So add them togher
	bli_setm(&BLIS_ZERO, &c_ref); // Both matrices we are going to add, have uplo settings
								  // Clear the destination matrix to avoid partial updates
	bli_copym(&c_org_tri, &c_ref);
	bli_addm(&c_result_tri, &c_ref);

#if PRINT
	bli_printm("c", &c, "%5.2f", "");
	bli_printm("c_org_tri", &c_org_tri, "%5.2f", "");
	bli_printm("c_result_tri", &c_result_tri, "%5.2f", "");
	bli_printm("c_ref", &c_ref, "%5.2f", "");
#endif

	// Repeat the experiment n_repeats times and record results.
	for (i = 0; i < n_repeats; ++i)
	{
		bli_copym(&c_save, &c);

		time = bli_clock();

		libblis_test_gemmt_impl(iface, &alpha, &a, &b, &beta, &c);

		time_min = bli_clock_min_diff(time_min, time);
	}

	// Estimate the performance of the best experiment repeat.
	*perf = (1.0 * m * m * k) / time_min / FLOPS_PER_UNIT_PERF;
	if (bli_obj_is_complex(&c))
		*perf *= 4.0;

	// Perform checks.
	libblis_test_gemmt_check(params, &alpha, &a, &b, &beta, &c, &c_ref, resid);

	// Zero out performance and residual if output matrix is empty.
	libblis_test_check_empty_problem(&c, perf, resid);

	// Free the test objects.
	bli_obj_free(&a);
	bli_obj_free(&b);
	bli_obj_free(&c);
	bli_obj_free(&c_ref);
	bli_obj_free(&c_org_tri);
	bli_obj_free(&c_result_tri);
	bli_obj_free(&c_save);
}

void libblis_test_gemmt_impl(
	iface_t iface,
	obj_t *alpha,
	obj_t *a,
	obj_t *b,
	obj_t *beta,
	obj_t *c)
{
	// Only the sequential front-end interface is implemented; any other
	// interface is reported as an error and nothing is computed.
	if (iface != BLIS_TEST_SEQ_FRONT_END)
	{
		libblis_test_printf_error("Invalid interface type.\n");
		return;
	}

#if PRINT
	bli_printm("a", a, "%5.2f", "");
	bli_printm("b", b, "%5.2f", "");
	bli_printm("c Before", c, "%5.2f", "");
#endif

	// Update the triangle of c selected by its uplo attribute.
	bli_gemmt(alpha, a, b, beta, c);

#if PRINT
	bli_printm("c after", c, "%5.2f", "");
#endif
}

void libblis_test_gemmt_check(
	test_params_t *params,
	obj_t *alpha,
	obj_t *a,
	obj_t *b,
	obj_t *beta,
	obj_t *c,
	obj_t *c_orig,
	double *resid)
{
	//
	// Pre-conditions:
	// - a and b are randomized.
	// - c holds the output of
	//     C := beta * C + alpha * transa(A) * transb(B)
	//   computed by gemmt on the triangle selected by c's uplo.
	// - c_orig holds the expected result of that operation (the experiment
	//   passes its reference matrix c_ref here).
	//
	// alpha, a, b and beta are not read; the comparison is done directly
	// against the expected matrix.
	//
	// The implementation is considered correct when
	//
	//   normfv( y_gemmt - y_ref )
	//
	// is negligible, where, for a random vector x_rand,
	//
	//   y_gemmt = C      * x_rand
	//   y_ref   = c_orig * x_rand
	//

	num_t dt = bli_obj_dt(c);
	num_t dt_real = bli_obj_dt_proj_to_real(c);
	dim_t m = bli_obj_length(c);

	obj_t resid_norm;
	obj_t x_rand, y_gemmt, y_ref;
	double resid_imag;

	// The norm is always real, even for complex operands.
	bli_obj_scalar_init_detached(dt_real, &resid_norm);

	bli_obj_create(dt, m, 1, 0, 0, &x_rand);
	bli_obj_create(dt, m, 1, 0, 0, &y_gemmt);
	bli_obj_create(dt, m, 1, 0, 0, &y_ref);

	libblis_test_vobj_randomize(params, TRUE, &x_rand);

	// Multiply both the computed and the expected matrix by the same vector.
	bli_gemv(&BLIS_ONE, c, &x_rand, &BLIS_ZERO, &y_gemmt);
	bli_gemv(&BLIS_ONE, c_orig, &x_rand, &BLIS_ZERO, &y_ref);

#if PRINT
	bli_printm("c-gemmt", c, "%5.2f", "");
	bli_printm("c-gemm", c_orig, "%5.2f", "");
	bli_printv("v", &y_gemmt, "%5.2f", "");
	bli_printv("z", &y_ref, "%5.2f", "");
#endif

	// y_gemmt := y_gemmt - y_ref, then report its Frobenius norm.
	bli_subv(&y_ref, &y_gemmt);
	bli_normfv(&y_gemmt, &resid_norm);
	bli_getsc(&resid_norm, resid, &resid_imag);

	bli_obj_free(&x_rand);
	bli_obj_free(&y_gemmt);
	bli_obj_free(&y_ref);
}
